package com.duan.laugh.msg.config;

import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.duan.laugh.common.core.constants.CoreConstants;
import com.duan.laugh.msg.api.constans.MsgConstants;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.RemoteTokenServices;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;

/**
 * WebSocket configuration.
 * <p>
 * {@code @EnableWebSocketMessageBroker} enables broker-backed messaging over the STOMP
 * sub-protocol (the "broker" relays messages from publishers to subscribed clients).
 * <p>
 * STOMP {@code CONNECT} frames must carry an {@code Authorization} native header of the form
 * {@code "<scheme> <token>"} (e.g. {@code "Bearer abc"}). The token is validated through
 * {@link RemoteTokenServices}, and the resulting authentication's name becomes the session user.
 *
 * @author duanjw
 */
@Slf4j
@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

    /** Native STOMP header that carries the access token on CONNECT. */
    private static final String AUTHORIZATION_HEADER = "Authorization";

    @Autowired
    private RemoteTokenServices tokenService;

    /**
     * Registers the {@code /ws} STOMP endpoint, allowing cross-origin handshakes and
     * enabling the SockJS fallback.
     */
    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        // NOTE(review): "*" accepts handshakes from any origin; consider restricting this
        // to the known front-end origins.
        registry.addEndpoint("/ws")
                .setAllowedOrigins("*")
                .withSockJS();
    }

    /**
     * Enables the simple in-memory broker for the broadcast topic prefix
     * ({@link MsgConstants#TOPIC}) and for {@code /user} destinations.
     */
    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        registry.enableSimpleBroker(MsgConstants.TOPIC, "/user");
    }

    /**
     * Installs an inbound interceptor that authenticates STOMP {@code CONNECT} frames.
     * <p>
     * CONNECT frames with a missing or malformed token, or one the token service resolves to
     * no authentication, are dropped by returning {@code null}. Any exception thrown by the
     * token service propagates to the caller. All other frames pass through unchanged.
     */
    @Override
    public void configureClientInboundChannel(ChannelRegistration registration) {
        registration.interceptors(new ChannelInterceptor() {
            @Override
            public Message<?> preSend(Message<?> message, MessageChannel channel) {
                StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
                // Only the initial CONNECT frame is authenticated; later frames belong to a
                // session that was already accepted. getAccessor may return null, so guard it.
                if (accessor == null || !StompCommand.CONNECT.equals(accessor.getCommand())) {
                    return message;
                }
                return authenticate(accessor) ? message : null;
            }
        });
    }

    /**
     * Validates the token on a CONNECT frame and binds the authenticated user to the session.
     *
     * @param accessor header accessor of the CONNECT frame
     * @return {@code true} if a token was present and the token service returned an
     *         authentication; {@code false} otherwise
     */
    private boolean authenticate(StompHeaderAccessor accessor) {
        String token = extractToken(accessor.getFirstNativeHeader(AUTHORIZATION_HEADER));
        if (token == null) {
            // Never log the header value itself: it is a credential.
            log.info("webSocket CONNECT rejected: missing or malformed Authorization header");
            return false;
        }
        // Exceptions from the token service (e.g. for a rejected token) propagate unchanged.
        OAuth2Authentication authentication = tokenService.loadAuthentication(token);
        if (ObjectUtil.isNotNull(authentication)) {
            // NOTE(review): this writes the thread-local security context of the thread calling
            // preSend and is never cleared here; confirm it is needed, since it may outlive
            // this frame if that thread is reused.
            SecurityContextHolder.getContext().setAuthentication(authentication);
            accessor.setUser(authentication::getName);
            return true;
        }
        return false;
    }

    /**
     * Extracts the credential from an {@code Authorization} header value of the form
     * {@code "<scheme> <token>"}.
     *
     * @param header raw header value; may be {@code null}
     * @return the token part, or {@code null} if the header is blank or has no token part
     */
    private static String extractToken(String header) {
        if (StrUtil.isBlank(header)) {
            return null;
        }
        // Split on any run of whitespace so "Bearer  abc" still yields "abc".
        String[] parts = header.trim().split("\\s+");
        return parts.length > CoreConstants.ONE ? parts[CoreConstants.ONE] : null;
    }
}
